{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IDdZSPcLtKx4"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Hub Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-g5By3P4tavy"
      },
      "outputs": [],
      "source": [
        "# Copyright 2019 The TensorFlow Hub Authors. All Rights Reserved.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS, \n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# =============================================================================="
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vpaLrN0mteAS"
      },
      "source": [
        "# 使用 TF-Hub 对孟加拉语文章进行分类"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://tensorflow.google.cn/hub/tutorials/bangla_article_classifier\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\">在 TensorFlow.org 上查看 </a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/hub/tutorials/bangla_article_classifier.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\">在 Google Colab 中运行 </a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/hub/tutorials/bangla_article_classifier.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\">在 GitHub 中查看源代码</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/hub/examples/colab/bangla_article_classifier.ipynb\">{img1下载笔记本</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GhN2WtIrBQ4y"
      },
      "source": [
        "小心：除了使用 pip 安装 Python 软件包外，此笔记本还使用 `sudo apt install` 安装系统软件包：`unzip`。\n",
        "\n",
        "此 Colab 演示了如何使用 [Tensorflow Hub](https://tensorflow.google.cn/hub/) 对非英语/本地语言进行文本分类。在这里，我们选择[孟加拉语](https://en.wikipedia.org/wiki/Bengali_language)作为本地语言并使用预训练的单词嵌入向量解决多类分类任务，在这个任务中我们将孟加拉语的新闻文章分为 5 类。针对孟加拉语进行预训练的嵌入向量来自 [FastText](https://fasttext.cc/docs/en/crawl-vectors.html)，这是一个由 Facebook 创建的库，其中包含 157 种语言的预训练单词向量。\n",
        "\n",
        "我们将使用 TF-Hub 的预训练嵌入向量导出程序先将单词嵌入向量转换为文本嵌入向量模块，然后使用该模块通过 [tf.keras](https://tensorflow.google.cn/api_docs/python/tf/keras)（Tensorflow 的高级用户友好 API）训练分类器来构建深度学习模型。即使我们在这里使用 fastText 嵌入向量，您也可以导出任何通过其他任务预训练的其他嵌入向量，并使用 Tensorflow Hub 快速获得结果。 "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q4DN769E2O_R"
      },
      "source": [
        "## 设置"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9Vt-StAAZguA"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "# https://github.com/pypa/setuptools/issues/1694#issuecomment-466010982\n",
        "pip install gdown --no-use-pep517"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WcBA19FlDPZO"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "sudo apt-get install -y unzip"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zSeyZMq-BYsu"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow_hub as hub\n",
        "\n",
        "import gdown\n",
        "import numpy as np\n",
        "from sklearn.metrics import classification_report\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9FB7gLU4F54l"
      },
      "source": [
        "# 数据集\n",
        "\n",
        "我们将使用 [BARD](https://www.researchgate.net/publication/328214545_BARD_Bangla_Article_Classification_Using_a_New_Comprehensive_Dataset)（孟加拉语文章数据集），内含从不同孟加拉语新闻门户收集的约 3,76,226 篇文章，并标记为 5 个类别：经济、国内、国际、体育和娱乐。我们从 Google 云端硬盘下载这个文件，此 ([bit.ly/BARD_DATASET](bit.ly/BARD_DATASET)) 链接指向[此](https://github.com/tanvirfahim15/BARD-Bangla-Article-Classifier) GitHub 仓库。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zdQrL_rwa-1K"
      },
      "outputs": [],
      "source": [
        "gdown.download(\n",
        "    url='https://drive.google.com/uc?id=1Ag0jd21oRwJhVFIBohmX_ogeojVtapLy',\n",
        "    output='bard.zip',\n",
        "    quiet=True\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P2YW4GGa9Y5o"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "unzip -qo bard.zip"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "js75OARBF_B8"
      },
      "source": [
        "# 将预训练的单词向量导出到 TF-Hub 模块"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-uAicYA6vLsf"
      },
      "source": [
        "TF-Hub 提供了一些方便的脚本将单词嵌入向量转换为 TF-Hub 文本嵌入向量模块，详见[这里](https://github.com/tensorflow/hub/tree/master/examples/text_embeddings_v2)。要使模块适用于孟加拉语或其他语言，我们只需将单词嵌入向量 .txt 或 .vec 文件下载到与 export_v2.py 相同的目录中，然后运行脚本。\n",
        "\n",
        "导出程序会读取嵌入向量，并将其导出到 Tensorflow [SavedModel](https://tensorflow.google.cn/beta/guide/saved_model)。SavedModel 包含完整的 TensorFlow 程序，其中包括权重和计算图。TF-Hub 可以将 SavedModel 作为[模块](https://tensorflow.google.cn/hub/api_docs/python/hub/Module)进行加载，我们将用它来构建文本分类模型。由于我们使用 tf.keras 来构建模型，因此我们将使用 [hub.KerasLayer](https://tensorflow.google.cn/hub/api_docs/python/hub/KerasLayer)，它为 Hub 模块提供用作 Keras 层的封装容器。\n",
        "\n",
        "首先，我们从 fastText 获得单词嵌入向量，并从 TF-Hub [仓库](https://github.com/tensorflow/hub)获得嵌入向量导出程序。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5DY5Ze6pO1G5"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "curl -O https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bn.300.vec.gz\n",
        "curl -O https://raw.githubusercontent.com/tensorflow/hub/master/examples/text_embeddings_v2/export_v2.py\n",
        "gunzip -qf cc.bn.300.vec.gz --k"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PAzdNZaHmdl1"
      },
      "source": [
        "然后，我们在嵌入向量文件上运行导出程序脚本。由于 fastText 嵌入向量具有标题行并且相当大（转换为模块后，孟加拉语大约为 3.3 GB），因此我们忽略第一行，仅将前 100, 000 个词例导入文本嵌入向量模块。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Tkv5acr_Q9UU"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "python export_v2.py --embedding_file=cc.bn.300.vec --export_path=text_module --num_lines_to_ignore=1 --num_lines_to_use=100000"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k9WEpmedF_3_"
      },
      "outputs": [],
      "source": [
        "module_path = \"text_module\"\n",
        "embedding_layer = hub.KerasLayer(module_path, trainable=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fQHbmS_D4YIo"
      },
      "source": [
        "文本嵌入向量模块以一维字符串张量中的句子批次作为输入，并输出与句子相对应的形状 (batch_size, embedding_dim) 的嵌入向量。它通过按空格拆分来对输入进行预处理。我们使用 `sqrtn` 组合程序（请参阅[此处](https://tensorflow.google.cn/api_docs/python/tf/nn/embedding_lookup_sparse)）将单词嵌入向量组合到句子嵌入向量。为了演示，我们传递一个孟加拉语单词的列表作为输入，并获得相应的嵌入向量。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Z1MBnaBUihWn"
      },
      "outputs": [],
      "source": [
        "embedding_layer(['বাস', 'বসবাস', 'ট্রেন', 'যাত্রী', 'ট্রাক']) "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4KY8LiFOHmcd"
      },
      "source": [
        "# 转换为 TensorFlow 数据集\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pNguCDNe6bvz"
      },
      "source": [
        "由于数据集确实很大，因此我们使用生成器通过 [Tensorflow 数据集](https://tensorflow.google.cn/api_docs/python/tf/data/Dataset)的功能在运行时批量生成样本，而不是将整个数据集加载到内存中。同时，数据集还非常不平衡，因此在使用生成器之前，我们将打乱数据集的顺序。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bYv6LqlEChO1"
      },
      "outputs": [],
      "source": [
        "dir_names = ['economy', 'sports', 'entertainment', 'state', 'international']\n",
        "\n",
        "file_paths = []\n",
        "labels = []\n",
        "for i, dir in enumerate(dir_names):\n",
        "  file_names = [\"/\".join([dir, name]) for name in os.listdir(dir)]\n",
        "  file_paths += file_names\n",
        "  labels += [i] * len(os.listdir(dir))\n",
        "  \n",
        "np.random.seed(42)\n",
        "permutation = np.random.permutation(len(file_paths))\n",
        "\n",
        "file_paths = np.array(file_paths)[permutation]\n",
        "labels = np.array(labels)[permutation]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8b-UtAP5TL-W"
      },
      "source": [
        "打乱顺序后，我们可以查看标签在训练和验证样本中的分布。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mimhWVSzzAmS"
      },
      "outputs": [],
      "source": [
        "train_frac = 0.8\n",
        "train_size = int(len(file_paths) * train_frac)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4BNXFrkotAYu"
      },
      "outputs": [],
      "source": [
        "# plot training vs validation distribution\n",
        "plt.subplot(1, 2, 1)\n",
        "plt.hist(labels[0:train_size])\n",
        "plt.title(\"Train labels\")\n",
        "plt.subplot(1, 2, 2)\n",
        "plt.hist(labels[train_size:])\n",
        "plt.title(\"Validation labels\")\n",
        "plt.tight_layout()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RVbHb2I3TUNA"
      },
      "source": [
        "要使用生成器创建[数据集](https://tensorflow.google.cn/api_docs/python/tf/data/Dataset)，我们首先编写一个生成器函数，该函数从 file_paths 读取文章，从标签数组中读取标签，并在每个步骤生成一个训练样本。我们将此生成器函数传递到 [tf.data.Dataset.from_generator](https://tensorflow.google.cn/api_docs/python/tf/data/Dataset#from_generator) 方法，并指定输出类型。每个训练样本都是一个元组，其中包含 tf.string 数据类型的文章和独热编码标签。我们使用 [`skip`](https://tensorflow.google.cn/api_docs/python/tf/data/Dataset#skip) 和 [`take`](https://tensorflow.google.cn/api_docs/python/tf/data/Dataset#take) 方法以 80-20 的比例将数据集拆分为训练集和验证集。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eZRGTzEhUi7Q"
      },
      "outputs": [],
      "source": [
        "def load_file(path, label):\n",
        "    return tf.io.read_file(path), label"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2g4nRflB7fbF"
      },
      "outputs": [],
      "source": [
        "def make_datasets(train_size):\n",
        "  batch_size = 256\n",
        "\n",
        "  train_files = file_paths[:train_size]\n",
        "  train_labels = labels[:train_size]\n",
        "  train_ds = tf.data.Dataset.from_tensor_slices((train_files, train_labels))\n",
        "  train_ds = train_ds.map(load_file).shuffle(5000)\n",
        "  train_ds = train_ds.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)\n",
        "\n",
        "  test_files = file_paths[train_size:]\n",
        "  test_labels = labels[train_size:]\n",
        "  test_ds = tf.data.Dataset.from_tensor_slices((test_files, test_labels))\n",
        "  test_ds = test_ds.map(load_file)\n",
        "  test_ds = test_ds.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)\n",
        "\n",
        "\n",
        "  return train_ds, test_ds"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8PuuN6el8tv9"
      },
      "outputs": [],
      "source": [
        "train_data, validation_data = make_datasets(train_size)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MrdZI6FqPJNP"
      },
      "source": [
        "# 模型训练和评估"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jgr7YScGVS58"
      },
      "source": [
        "由于我们已经在模块周围添加了封装容器，使其可以像 Keras 中的任何其他层一样使用，因此我们可以创建一个小的[序贯](https://tensorflow.google.cn/api_docs/python/tf/keras/Sequential)模型，此模型是层的线性堆叠。我们可以像使用任何其他层一样，使用 `model.add` 添加文本嵌入向量模块。我们通过指定损失和优化器来编译模型，并对其进行 10 个周期的训练。`tf.keras` API 可以将 TensorFlow 数据集作为输入进行处理，因此我们可以将数据实例传递给用于模型训练的拟合方法。由于我们使用的是生成器函数，tf.data 将负责生成样本、对其进行批处理，并将其馈送给模型。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WhCqbDK2uUV5"
      },
      "source": [
        "## 模型"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nHUw807XPPM9"
      },
      "outputs": [],
      "source": [
        "def create_model():\n",
        "  model = tf.keras.Sequential([\n",
        "    tf.keras.layers.Input(shape=[], dtype=tf.string),\n",
        "    embedding_layer,\n",
        "    tf.keras.layers.Dense(64, activation=\"relu\"),\n",
        "    tf.keras.layers.Dense(16, activation=\"relu\"),\n",
        "    tf.keras.layers.Dense(5),\n",
        "  ])\n",
        "  model.compile(loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "      optimizer=\"adam\", metrics=['accuracy'])\n",
        "  return model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5J4EXJUmPVNG"
      },
      "outputs": [],
      "source": [
        "model = create_model()\n",
        "# Create earlystopping callback\n",
        "early_stopping_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZZ7XJLg2u2No"
      },
      "source": [
        "## 训练"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OoBkN2tAaXWD"
      },
      "outputs": [],
      "source": [
        "history = model.fit(train_data, \n",
        "                    validation_data=validation_data, \n",
        "                    epochs=5, \n",
        "                    callbacks=[early_stopping_callback])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XoDk8otmMoT7"
      },
      "source": [
        "## 评估"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G5ZRKGOsXEh4"
      },
      "source": [
        "我们可以使用由 `fit` 方法返回的 `history` 对象（包含每个周期的损失和准确率值）来可视化训练和验证数据的准确率和损失曲线。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V6tOnByIOeGn"
      },
      "outputs": [],
      "source": [
        "# Plot training &amp; validation accuracy values\n",
        "plt.plot(history.history['accuracy'])\n",
        "plt.plot(history.history['val_accuracy'])\n",
        "plt.title('Model accuracy')\n",
        "plt.ylabel('Accuracy')\n",
        "plt.xlabel('Epoch')\n",
        "plt.legend(['Train', 'Test'], loc='upper left')\n",
        "plt.show()\n",
        "\n",
        "# Plot training &amp; validation loss values\n",
        "plt.plot(history.history['loss'])\n",
        "plt.plot(history.history['val_loss'])\n",
        "plt.title('Model loss')\n",
        "plt.ylabel('Loss')\n",
        "plt.xlabel('Epoch')\n",
        "plt.legend(['Train', 'Test'], loc='upper left')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D54IXLqcG8Cq"
      },
      "source": [
        "## 预测\n",
        "\n",
        "我们可以获得验证数据的预测并检查混淆矩阵，以查看模型在 5 个类中的性能。`predict` 方法返回每个类的概率的 N 维数组后，我们使用 `np.argmax` 将其转换为类标签。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dptEywzZJk4l"
      },
      "outputs": [],
      "source": [
        "y_pred = model.predict(validation_data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7Dzeml6Pk0ub"
      },
      "outputs": [],
      "source": [
        "y_pred = np.argmax(y_pred, axis=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T4M3Lzg8jHcB"
      },
      "outputs": [],
      "source": [
        "samples = file_paths[0:3]\n",
        "for i, sample in enumerate(samples):\n",
        "  f = open(sample)\n",
        "  text = f.read()\n",
        "  print(text[0:100])\n",
        "  print(\"True Class: \", sample.split(\"/\")[0])\n",
        "  print(\"Predicted Class: \", dir_names[y_pred[i]])\n",
        "  f.close() "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PlDTIpMBu6h-"
      },
      "source": [
        "## 比较性能\n",
        "\n",
        "现在，我们可以从 `labels` 获得验证数据的正确标签，并与我们的预测进行比较，以获得 [classification_report](http://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html)。 "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mqrERUCS1Xn7"
      },
      "outputs": [],
      "source": [
        "y_true = np.array(labels[train_size:])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NX5w-NuTKuVP"
      },
      "outputs": [],
      "source": [
        "print(classification_report(y_true, y_pred, target_names=dir_names))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p5e9m3bV6oXK"
      },
      "source": [
        "我们还可以将模型的性能与原始[论文](https://www.researchgate.net/publication/328214545_BARD_Bangla_Article_Classification_Using_a_New_Comprehensive_Dataset)中报告的精度为 0.96 的发布结果进行比较。原作者描述了在数据集上完成的许多预处理步骤，例如删除标点和数字、去除前 25 个最常见的停用词等。正如我们在 classification_report 中所见，在仅训练了 5 个周期而没有进行任何预处理的情况下，我们也获得了 0.96 的精度和准确率！\n",
        "\n",
        "在此示例中，当我们从嵌入向量模块创建 Keras 层时，我们设置了 `trainable=False`，这意味着训练期间不会更新嵌入向量权重。请尝试将此设置为 True，使用此数据集仅用 2 个周期即可达到 97% 的准确率。 "
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "IDdZSPcLtKx4"
      ],
      "name": "bangla_article_classifier.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
